# -*- coding:utf-8 -*-
# @time :2021.6.10
# @IDE : pycharm
# @author :likecy
# @github : https://gitee.com/likecy

from torch.hub import load_state_dict_from_url
import torch.nn as nn
import torch
import torchvision
from models import resnet101, densenet121, densenet169, resnet18, resnet50, resnet152, mobilenet_v2
from models import resnext101_32x8d_wsl, resnext101_32x16d_wsl, resnext101_32x32d_wsl, resnext101_32x48d_wsl
from models import squeezenet1_0, squeezenet1_1
from models import shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5, shufflenet_v2_x2_0
from models import EfficientNet
from models import googlenet, alexnet
import torch.utils.model_zoo as model_zoo

# Download locations for pretrained weights, keyed by the names the factory
# functions below look up.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    # This URL used to share the 'resnext101_32x8d' key with the ig_ URL below
    # and was silently overwritten by it; it now has its own key so it is not
    # lost. 'resnext101_32x8d' keeps the value it effectively always had.
    'resnext101_32x8d_imagenet': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth',
    'resnext101_32x16d': 'https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth',
    'resnext101_32x32d': 'https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth',
    'resnext101_32x48d': 'https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth',
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    # NOTE: key is misspelled ('moblienetv2'); kept as-is because Mobilenetv2
    # looks it up under this exact name.
    'moblienetv2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
    'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
    'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
    'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
    'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
    'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
    'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
    'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
    'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
}


def load_model(model, pretrained_state_dict):
    """Copy every shape-compatible tensor from ``pretrained_state_dict`` into ``model``.

    An entry is used only when ``model.state_dict()`` has the same key with an
    identical ``size()``; all other entries are ignored, and loading is done
    with ``strict=False`` so missing keys are tolerated.

    After loading, a report is printed: a single notice when nothing matched,
    otherwise one "Load"/"Skip" line per entry of ``pretrained_state_dict``.

    Returns the same ``model`` object, updated in place.
    """
    own_params = model.state_dict()
    usable = {}
    for name, tensor in pretrained_state_dict.items():
        if name in own_params and own_params[name].size() == tensor.size():
            usable[name] = tensor

    model.load_state_dict(usable, strict=False)

    if not usable:
        print("[INFO] No params were loaded ...")
        return model

    for name, tensor in pretrained_state_dict.items():
        if name not in usable:
            print("[INFO] Skip {} {}".format(name, tensor.size()))
        else:
            print("==>> Load {} {}".format(name, tensor.size()))
    return model


# ==========alexnet==============
def Alexnet(num_classes, pretrained=False):
    """Return ``alexnet`` built with ``num_classes`` outputs.

    ``pretrained`` is forwarded unchanged to ``alexnet``. The signature mirrors
    the ``(num_classes, pretrained)`` shape of the other factories here.
    """
    return alexnet(num_classes=num_classes, pretrained=pretrained)


# ==========googlenet==============
def Googlenet(num_classes, pretrained=False):
    """Build ``googlenet`` and swap its ``fc`` head for a ``num_classes``-way Linear.

    ``pretrained`` is passed straight through to ``googlenet``; the new head
    keeps the input width of the head it replaces.
    """
    net = googlenet(pretrained=pretrained)
    head_in = net.fc.in_features
    net.fc = nn.Linear(head_in, num_classes)
    return net

# ==========resnet==============


def Resnet18(num_classes, pretrained=False):
    """Build a ResNet-18 whose ``fc`` head outputs ``num_classes`` logits.

    When ``pretrained`` is true, weights from ``model_urls['resnet18']`` are
    downloaded and copied in via ``load_model`` before the head is replaced.
    """
    net = resnet18()
    if pretrained:
        url = model_urls['resnet18']
        print("Load pretrained model from {}".format(url))
        net = load_model(net, model_zoo.load_url(url))
    # New head keeps the backbone's feature width.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net


def Resnet50(num_classes, pretrained=False):
    """Build a ResNet-50 whose ``fc`` head outputs ``num_classes`` logits.

    When ``pretrained`` is true, weights from ``model_urls['resnet50']`` are
    downloaded and copied in via ``load_model`` before the head is replaced.
    """
    net = resnet50()
    if pretrained:
        url = model_urls['resnet50']
        print("Load pretrained model from {}".format(url))
        net = load_model(net, model_zoo.load_url(url))
    # New head keeps the backbone's feature width.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net


def Resnet101(num_classes, pretrained=False):
    """Build a ResNet-101 whose ``fc`` head outputs ``num_classes`` logits.

    When ``pretrained`` is true, weights from ``model_urls['resnet101']`` are
    downloaded and copied in via ``load_model`` before the head is replaced.
    """
    net = resnet101()
    if pretrained:
        url = model_urls['resnet101']
        print("Load pretrained model from {}".format(url))
        net = load_model(net, model_zoo.load_url(url))
    # New head keeps the backbone's feature width.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net


def Resnet152(num_classes, pretrained=False):
    """Build a ResNet-152 whose ``fc`` head outputs ``num_classes`` logits.

    When ``pretrained`` is true, weights from ``model_urls['resnet152']`` are
    downloaded and copied in via ``load_model`` before the head is replaced.
    """
    net = resnet152()
    if pretrained:
        url = model_urls['resnet152']
        print("Load pretrained model from {}".format(url))
        net = load_model(net, model_zoo.load_url(url))
    # New head keeps the backbone's feature width.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net


def Resnext101_32x8d(num_classes, pretrained=False):
    """Build ``resnext101_32x8d_wsl`` with a ``num_classes``-way ``fc`` head.

    When ``pretrained`` is true, weights from
    ``model_urls['resnext101_32x8d']`` are downloaded and copied in via
    ``load_model`` before the head is replaced.
    """
    net = resnext101_32x8d_wsl()
    if pretrained:
        url = model_urls['resnext101_32x8d']
        print("Load pretrained model from {}".format(url))
        net = load_model(net, model_zoo.load_url(url))
    # New head keeps the backbone's feature width.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net


def Resnext101_32x16d(num_classes, pretrained=False):
    """Build ``resnext101_32x16d_wsl`` with a ``num_classes``-way ``fc`` head.

    When ``pretrained`` is true, weights from
    ``model_urls['resnext101_32x16d']`` are downloaded and copied in via
    ``load_model`` before the head is replaced.
    """
    net = resnext101_32x16d_wsl()
    if pretrained:
        url = model_urls['resnext101_32x16d']
        print("Load pretrained model from {}".format(url))
        net = load_model(net, model_zoo.load_url(url))
    # New head keeps the backbone's feature width.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net


def Resnext101_32x32d(num_classes, pretrained=False):
    """Build ``resnext101_32x32d_wsl`` with a ``num_classes``-way ``fc`` head.

    When ``pretrained`` is true, weights from
    ``model_urls['resnext101_32x32d']`` are downloaded and copied in via
    ``load_model`` before the head is replaced.
    """
    net = resnext101_32x32d_wsl()
    if pretrained:
        url = model_urls['resnext101_32x32d']
        print("Load pretrained model from {}".format(url))
        net = load_model(net, model_zoo.load_url(url))
    # New head keeps the backbone's feature width.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net


def Resnext101_32x48d(num_classes, pretrained=False):
    """Build ``resnext101_32x48d_wsl`` with a ``num_classes``-way ``fc`` head.

    When ``pretrained`` is true, weights from
    ``model_urls['resnext101_32x48d']`` are downloaded and copied in via
    ``load_model`` before the head is replaced.
    """
    net = resnext101_32x48d_wsl()
    if pretrained:
        url = model_urls['resnext101_32x48d']
        print("Load pretrained model from {}".format(url))
        net = load_model(net, model_zoo.load_url(url))
    # New head keeps the backbone's feature width.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net

# ==========Densenet==============


def Densenet121(num_classes, pretrained=False):
    model = densenet121()
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls['densenet121'], progress=True)
        # if LOCAL_PRETRAINED['densenet121'] == None:
        #     state_dict = load_state_dict_from_url(model_urls['densenet121'], progress=True)
        # else:
        #     state_dict = state_dict = torch.load(LOCAL_PRETRAINED['densenet121'])

        from collections import OrderedDict
        new_state_dict = OrderedDict()

        for k, v in state_dict.items():
            # print(k)  #打印预训练模型的键，发现与网络定义的键有一定的差别，因而需要将键值进行对应的更改，将键值分别对应打印出来就可以看出不同，根据不同进行修改
            # torchvision中的网络定义，采用了正则表达式，来更改键值，因为这里简单，没有再去构建正则表达式
            # 直接利用if语句筛选不一致的键
            # 修正键值的不对应
            if k.split('.')[0] == 'features' and (len(k.split('.'))) > 4:
                k = k.split('.')[0]+'.'+k.split('.')[1]+'.'+k.split('.')[2] + \
                    '.'+k.split('.')[-3] + k.split('.')[-2] + \
                    '.'+k.split('.')[-1]
            # print(k)
            else:
                pass
            new_state_dict[k] = v
        model.load_state_dict(new_state_dict)
    fc_features = model.classifier.in_features
    model.classifier = nn.Linear(fc_features, num_classes)
    return model


def Densenet169(num_classes, pretrained=False):
    model = densenet169()
    if pretrained:
        state_dict = load_state_dict_from_url(
            model_urls['densenet169'], progress=True)
        # if LOCAL_PRETRAINED['densenet169'] == None:
        #     state_dict = load_state_dict_from_url(model_urls['densenet169'], progress=True)
        # else:
        #     state_dict = state_dict = torch.load(LOCAL_PRETRAINED['densenet169'])

        from collections import OrderedDict
        new_state_dict = OrderedDict()

        for k, v in state_dict.items():
            # print(k)  #打印预训练模型的键，发现与网络定义的键有一定的差别，因而需要将键值进行对应的更改，将键值分别对应打印出来就可以看出不同，根据不同进行修改
            # torchvision中的网络定义，采用了正则表达式，来更改键值，因为这里简单，没有再去构建正则表达式
            # 直接利用if语句筛选不一致的键
            # 修正键值的不对应
            if k.split('.')[0] == 'features' and (len(k.split('.'))) > 4:
                k = k.split('.')[0]+'.'+k.split('.')[1]+'.'+k.split('.')[2] + \
                    '.'+k.split('.')[-3] + k.split('.')[-2] + \
                    '.'+k.split('.')[-1]
            # print(k)
            else:
                pass
            new_state_dict[k] = v
        model.load_state_dict(new_state_dict)
    fc_features = model.classifier.in_features
    model.classifier = nn.Linear(fc_features, num_classes)
    return model

# ==========Mobilenet==============


def Mobilenetv2(num_classes, pretrained=False):
    """Build a MobileNetV2 whose classifier outputs ``num_classes`` logits.

    When ``pretrained`` is true, weights from ``model_urls['moblienetv2']``
    (sic — that is the key's spelling) are downloaded and copied in via
    ``load_model`` before the head is replaced.
    """
    model = mobilenet_v2()
    if pretrained:
        url = model_urls['moblienetv2']
        print("Load pretrained model from {}".format(url))
        model = load_model(model, model_zoo.load_url(url))
    # The input width is read from element 1 of the existing classifier, but
    # the whole classifier is then replaced by a single Linear.
    # NOTE(review): anything before index 1 in the original classifier
    # (presumably a Dropout, as in torchvision) is discarded — confirm intended.
    fc_features = model.classifier[1].in_features
    model.classifier = nn.Linear(fc_features, num_classes)
    return model

# ==========Efficientnet===========


def Efficientnet(model_name, num_classes, pretrained=False):
    """Build an EfficientNet variant whose ``_fc`` head outputs ``num_classes`` logits.

    Args:
        model_name: variant name such as ``'efficientnet-b0'`` through
            ``'efficientnet-b7'``. It is passed to ``EfficientNet.from_name``
            and, when ``pretrained`` is true, also used as a ``model_urls`` key.
        num_classes: width of the replacement ``_fc`` layer.
        pretrained: if true, download the variant's weights and copy them in
            via ``load_model`` before the head is replaced.
    """
    net = EfficientNet.from_name(model_name)
    if pretrained:
        weights_url = model_urls[model_name]
        print("Load pretrained model from {}".format(weights_url))
        net = load_model(net, model_zoo.load_url(weights_url))
    net._fc = nn.Linear(net._fc.in_features, num_classes)
    return net


# ==========Squeezenet=============
def Squeezenet1_0(num_classes, pretrained=False):
    """Return ``squeezenet1_0`` built with ``num_classes`` outputs.

    Both arguments are forwarded unchanged. Unlike most factories here, no
    head is replaced afterwards, because ``num_classes`` is handed to the
    constructor directly.
    """
    return squeezenet1_0(num_classes=num_classes, pretrained=pretrained)


def Squeezenet1_1(num_classes, pretrained=False):
    """Return ``squeezenet1_1`` built with ``num_classes`` outputs.

    Both arguments are forwarded unchanged. Unlike most factories here, no
    head is replaced afterwards, because ``num_classes`` is handed to the
    constructor directly.
    """
    return squeezenet1_1(num_classes=num_classes, pretrained=pretrained)


# ==========Shufflenet=============
def Shufflenet_v2_x0_5(num_classes, pretrained=False):
    """Build ``shufflenet_v2_x0_5`` and swap its ``fc`` for a ``num_classes``-way Linear.

    ``pretrained`` is passed straight through to ``shufflenet_v2_x0_5``.
    """
    net = shufflenet_v2_x0_5(pretrained=pretrained)
    head_in = net.fc.in_features
    net.fc = nn.Linear(head_in, num_classes)
    return net


def Shufflenet_v2_x1_0(num_classes, pretrained=False):
    """Build ``shufflenet_v2_x1_0`` and swap its ``fc`` for a ``num_classes``-way Linear.

    ``pretrained`` is passed straight through to ``shufflenet_v2_x1_0``.
    """
    net = shufflenet_v2_x1_0(pretrained=pretrained)
    head_in = net.fc.in_features
    net.fc = nn.Linear(head_in, num_classes)
    return net


def Shufflenet_v2_x1_5(num_classes, pretrained=False):
    """Build ``shufflenet_v2_x1_5`` and swap its ``fc`` for a ``num_classes``-way Linear.

    ``pretrained`` is passed straight through to ``shufflenet_v2_x1_5``.
    """
    net = shufflenet_v2_x1_5(pretrained=pretrained)
    head_in = net.fc.in_features
    net.fc = nn.Linear(head_in, num_classes)
    return net


def Shufflenet_v2_x2_0(num_classes, pretrained=False):
    """Build ``shufflenet_v2_x2_0`` and swap its ``fc`` for a ``num_classes``-way Linear.

    ``pretrained`` is passed straight through to ``shufflenet_v2_x2_0``.
    """
    net = shufflenet_v2_x2_0(pretrained=pretrained)
    head_in = net.fc.in_features
    net.fc = nn.Linear(head_in, num_classes)
    return net
